#include <os/debug.h>
#include <os/driver.h>
#include <os/memcache.h>
#include <os/hardirq.h>
#include <os/dma.h>
#include <os/debug.h>
#include <os/mutexlock.h>
#include <os/initcall.h>
#include <os/schedule.h>
#include <lib/stdarg.h>
#include <lib/string.h>
#include <lib/type.h>
#include <lib/errno.h>
#include <lib/unistd.h>
#include <sys/ioctl.h>
#include <sys/res.h>
#include <driver/sb16.h>

// Forward declarations of the file-local helpers.
// Functions without parameters are declared with (void) so that a call which
// accidentally passes arguments is rejected by the compiler instead of being
// silently accepted as an unprototyped call.
static void Sb16DspWrite(uint8_t value);
static uint8_t Sb16DspRead(void);
static void Sb16SetRate(uint16_t hz);
static void Sb16Request(device_extension_t *extension);
static void Sb16SetVolume(device_extension_t *extension, uint8_t l_vol, uint8_t r_vol);
static void Sb16EnSpeack(void);
static int Sb16Init(device_extension_t *extension);
static void Sb16DmaStart(dma_region_t *dma_region, uint32_t length);

// Sb16DspWrite - send one byte to the DSP.
// Polls the DSP_WRITE port until bit 7 reads clear (DSP ready to accept a
// byte), then writes `value` to that same port. The poll has no timeout.
static void Sb16DspWrite(uint8_t value)
{
    for (;;)
    {
        int busy = (In8((uint16_t)DSP_WRITE) & 0x80) != 0;
        if (!busy)
            break;
    }
    Out8((uint16_t)DSP_WRITE, value);
}

// Sb16DspRead - receive one byte from the DSP.
// Polls DSP_STATUS until bit 7 is set (data available), then returns the
// byte read from DSP_READ. The poll has no timeout.
static uint8_t Sb16DspRead()
{
    uint8_t status;

    do
    {
        status = In8((uint16_t)DSP_STATUS);
    } while ((status & 0x80) == 0);

    return In8(DSP_READ);
}

// Sb16SetRate - program the DSP output sample rate.
// Sends DSP_SET_RATE followed by the rate, high byte first.
static void Sb16SetRate(uint16_t hz)
{
    const uint8_t rate_hi = (uint8_t)(hz >> 8);
    const uint8_t rate_lo = (uint8_t)(hz & 0xff);

    Sb16DspWrite((uint8_t)DSP_SET_RATE);
    Sb16DspWrite(rate_hi);
    Sb16DspWrite(rate_lo);
}

// Sb16SetVolume - write the mixer master-volume register.
// Selects DSP_MASTER_VOLUME through the mixer address port, then writes a byte
// whose high nibble is the left level and low nibble the right level; only the
// low 4 bits of each argument are used. `extension` is not consulted.
static void Sb16SetVolume(device_extension_t *extension, uint8_t l_vol, uint8_t r_vol)
{
    const uint8_t packed = (uint8_t)(((l_vol & 0x0F) << 4) | (r_vol & 0x0F));

    (void)extension;
    Out8((uint16_t)DSP_MIXED, (uint8_t)DSP_MASTER_VOLUME);
    Out8((uint16_t)DSP_MIXED_DATA, packed);
}

// Sb16EnSpeack - turn on the DSP speaker output (DSP_ENABLE_SPEAKE command).
static void Sb16EnSpeack()
{
    const uint8_t command = DSP_ENABLE_SPEAKE;
    Sb16DspWrite(command);
}

// Sb16Init - reset the DSP and prepare it for playback.
//
// Pulses DSP_RESET, expects the DSP to answer 0xAA, queries and records the
// DSP version in `extension`, enables the speaker and sets a 20000 Hz output
// rate. extension->volume is initialised to VOL_MID, but the mixer itself is
// not programmed here.
//
// NOTE(review): Sb16DspRead polls without a timeout, so this can spin if the
// status port never reports data.
//
// Returns 0 on success, -1 if the reset acknowledge is not 0xAA.
static int Sb16Init(device_extension_t *extension)
{
    int ack;
    int ver_major;
    int ver_minor;

    // pulse the reset line
    Out8((uint16_t)DSP_RESET, (uint8_t)1);
    CpuDoDelay(1);
    Out8((uint16_t)DSP_RESET, (uint8_t)0);

    ack = Sb16DspRead();
    if (ack != 0xaa)
    {
        KPrint(PRINT_ERR "sb16: sb16 not ready!\n");
        return -1;
    }

    // the DSP answers DSP_GET_VERSION with major then minor
    Sb16DspWrite(DSP_GET_VERSION);
    ver_major = Sb16DspRead();
    ver_minor = Sb16DspRead();
    extension->major_version = ver_major;
    extension->minor_version = ver_minor;
    KPrint("sb16: found version %d.%d\n", extension->major_version, ver_minor);

    Sb16EnSpeack();
    Sb16SetRate((uint16_t)20000);

    extension->volume = VOL_MID;

    KPrint("[sb16] enable sb16 driver\n");
    return 0;
}

// Sb16Request - start playback of the block at the read index.
//
// Does nothing when the ring is empty (index_r == index_w). Otherwise it
// programs the DMA controller for the block, then issues DSP_PLAY_16BIT with
// a signed/stereo mode byte and a 16-bit transfer count (low byte first).
// The count is the byte length / 2, halved again for stereo, minus one.
//
// NOTE(review): a block shorter than 4 bytes makes that count wrap to 0xFFFF.
static void Sb16Request(device_extension_t *extension)
{
    const uint8_t fmt = DSP_PLAY_SIGNED | DSP_PLAY_STEREO;
    dma_region_t *region;
    int bytes;
    uint16_t samples;

    if (extension->index_r == extension->index_w)
        return; // nothing queued

    region = &extension->dma_region[extension->index_r];
    bytes = extension->date_len[extension->index_r];

    Sb16DmaStart(region, bytes);

    samples = bytes / sizeof(uint16_t);
    if (fmt & DSP_PLAY_STEREO)
        samples /= 2;
    samples -= 1;

    Sb16DspWrite(DSP_PLAY_16BIT);
    Sb16DspWrite(fmt);
    Sb16DspWrite((uint8_t)samples);
    Sb16DspWrite((uint8_t)(samples >> 8));
    KPrint("sb16: [DMA] %x sample count %d\n", region->v, samples);
}

// Sb16DmaStart - program ISA DMA channel 5 to feed a buffer to the card.
//
// Channel 5 belongs to the second (16-bit) 8237 controller, which addresses
// and counts in 16-bit words:
//   - the offset register holds address bits A1..A16 (addr / 2),
//   - the page register supplies A17..A23 (its bit 0 is ignored),
//   - the count register holds (number of words - 1).
//
// @dma_region: buffer to transfer; dma_region->p.addr is its physical address.
//              NOTE(review): 16-bit ISA DMA needs the buffer below 16 MiB and
//              not crossing a 128 KiB boundary — assumed to be guaranteed by
//              the allocator, confirm.
// @length:     transfer size in BYTES (expected to be even).
static void Sb16DmaStart(dma_region_t *dma_region, uint32_t length)
{
    const uint32_t addr = dma_region->p.addr;
    const uint8_t channel = 5; // 16bits use DMA channel 5
    // Mode register: single transfer (0x40) | "read" transfer, i.e. memory ->
    // device (0x08). Transfer-type bits 00 would select a verify cycle, which
    // moves no data on real hardware.
    const uint8_t mode = 0x40 | 0x08;
    // The 16-bit controller works in words.
    const uint16_t offset = (uint16_t)((addr >> 1) & 0xffff);
    // Bytes -> (words - 1). The shift must happen before truncating to 8 bits,
    // otherwise the high byte written below is always zero.
    const uint16_t count = (uint16_t)((length - 1) >> 1);

    // mask (disable) the channel while it is reprogrammed
    Out8((uint16_t)0xd4, (uint8_t)(4 + (channel & 0x3)));

    // reset the low/high byte flip-flop
    Out8((uint16_t)0xd8, (uint8_t)0);

    // DMA mode
    Out8((uint16_t)0xd6, (uint8_t)(mode | (channel & 0x3)));

    // word offset, low byte then high byte
    Out8((uint16_t)0xc4, (uint8_t)(offset & 0xff));
    Out8((uint16_t)0xc4, (uint8_t)(offset >> 8));

    // word count - 1, low byte then high byte
    Out8((uint16_t)0xc6, (uint8_t)(count & 0xff));
    Out8((uint16_t)0xc6, (uint8_t)(count >> 8));

    // page register for channel 5 (A17..A23; bit 0 unused for 16-bit DMA)
    Out8((uint16_t)0x8b, (uint8_t)((addr >> 16) & 0xfe));

    // unmask (enable) the channel
    Out8((uint16_t)0xd4, (uint8_t)(channel & 0x3));
}

// Sb16Handler - IRQ handler invoked when the DSP finishes a block.
//
// Pauses 16-bit output, acknowledges the interrupt, retires the finished
// block by advancing index_r, wakes all writers waiting for a free slot and,
// if more blocks are queued, starts the next one.
// @data: the device_extension_t registered with IrqRegister.
// Always returns 0.
static int Sb16Handler(irqno_t irq, void *data)
{
    device_extension_t *ext = (device_extension_t *)data;

    (void)irq;
    KPrint("sb16 interrupt\n");

    Sb16DspWrite(DSP_PAUSE_16BIT);

    // acknowledge: DSP_STATUS read for an 8-bit interrupt, plus DSP_R_ACK for
    // a 16-bit interrupt on DSP version 4 and later
    In8(DSP_STATUS);
    if (ext->major_version >= 4)
        In8(DSP_R_ACK);

    ext->index_r = (ext->index_r + 1) % DMA_COUNT;
    WaitQueueWakeupAll(&ext->waiter);

    if (ext->index_r == ext->index_w)
        return 0; // ring drained

    Sb16Request(ext);
    return 0;
}

// __Sb16Write - copy one block of PCM data into the ring and start playback
// if the card is idle.
//
// Blocks (WaitQueueAdd + TaskBlock) while the ring is full, copies `data` into
// the DMA buffer at index_w, records its length and advances index_w. When
// the ring was empty before this block, playback is kicked off immediately;
// otherwise the IRQ handler picks the block up later.
//
// Returns `length` on success, 0 for an empty write (nothing is queued: a
// zero-length block would make the DSP transfer count wrap to 0xFFFF and play
// garbage), or -ENOSPC (carried in the size_t) if `length` exceeds the
// buffer size.
//
// NOTE(review): index_r is also updated from Sb16Handler; the full/idle
// checks here are not protected against that interrupt — confirm intent.
static size_t __Sb16Write(device_extension_t *extension, const uint8_t *data, size_t length)
{
    dma_region_t *dma_region = &extension->dma_region[extension->index_w];
    int was_idle;

    if (length == 0)
        return 0;

    if (length > dma_region->p.size)
    {
        KPrint("[sb16] write length err\n");
        return -ENOSPC;
    }

    // wait until the slot after index_w is not the one being played
    while (((extension->index_w + 1) % DMA_COUNT) == extension->index_r)
    {
        WaitQueueAdd(&extension->waiter, cur_task);
        TaskBlock(TASK_BLOCKED);
    }

    extension->date_len[extension->index_w] = length;
    memcpy((void *)dma_region->v, (const void *)data, length);

    // ring empty before this block => nothing is playing, so start it here
    was_idle = (extension->index_w == extension->index_r);
    extension->index_w = (extension->index_w + 1) % DMA_COUNT; // next block
    if (was_idle)
        Sb16Request(extension); // play sound

    KPrint("[sb16] write %dbytes\n", length);
    return length;
}

// Sb16Write - IOREQ_WRITE dispatch: queue the caller's buffer for playback.
//
// Completes `ioreq` with IO_SUCCESS and the number of bytes accepted in
// io_status.info, or with IO_FAILED when __Sb16Write reports an error.
static iostatus_t Sb16Write(device_object_t *device, io_request_t *ioreq)
{
    iostatus_t status = IO_SUCCESS;
    // __Sb16Write returns a negative errno inside its size_t result; the
    // conversion to int restores the sign on the usual two's-complement targets.
    int len = __Sb16Write(device->device_extension, ioreq->user_buff, ioreq->parame.write.len);
    if (len < 0)
    {
        status = IO_FAILED;
        // the format contains %s, so its argument must be supplied
        KPrint("%s: sb16 write failed!\n", __func__);
    }

    ioreq->io_status.status = status;
    ioreq->io_status.info = len;
    IoCompleteRequest(ioreq);
    return status;
}

// Sb16Open - IOREQ_OPEN dispatch. No per-open state is kept; the request is
// completed immediately with IO_SUCCESS and info 0.
static iostatus_t Sb16Open(device_object_t *device, io_request_t *ioreq)
{
    (void)device;
    ioreq->io_status.status = IO_SUCCESS;
    ioreq->io_status.info = 0;
    IoCompleteRequest(ioreq);
    return IO_SUCCESS;
}

// Sb16Close - IOREQ_CLOSE dispatch. Nothing to release; the request is
// completed immediately with IO_SUCCESS and info 0.
static iostatus_t Sb16Close(device_object_t *device, io_request_t *ioreq)
{
    (void)device;
    ioreq->io_status.status = IO_SUCCESS;
    ioreq->io_status.info = 0;
    IoCompleteRequest(ioreq);
    return IO_SUCCESS;
}

// Sb16Enter - driver entry: create the sound device, allocate its DMA ring,
// initialise the card and hook IRQ5.
//
// Returns the IoCreateDevice status on success or when device creation fails,
// and IO_FAILED if a DMA buffer cannot be allocated or Sb16Init fails (the
// device is deleted in both of those cases).
//
// NOTE(review): on the failure paths the DMA buffers already allocated are not
// released, and the IrqRegister result is not checked.
static iostatus_t Sb16Enter(driver_object_t *driver)
{
    device_object_t *device;
    device_extension_t *extension;
    iostatus_t status = IoCreateDevice(driver, sizeof(device_extension_t), DEVICE_NAME, DEVICE_TYPE_SOUND, &device);

    if (status != IO_SUCCESS)
    {
        KPrint(PRINT_ERR "%s:create device failed!\n", __func__);
        return status;
    }

    device->flags = 0;
    extension = device->device_extension;
    extension->index_w = 0; // empty ring: read and write indices meet
    extension->index_r = 0;

    for (int slot = 0; slot < DMA_COUNT; slot++)
    {
        dma_region_t *region = &extension->dma_region[slot];

        // NOTE(review): 16 pages is 64 KiB with 4 KiB pages (not 32 KiB), and a
        // 0x1000 alignment alone does not stop a 64 KiB buffer from crossing a
        // 128 KiB boundary, which 16-bit ISA DMA cannot do — confirm that
        // DMA_REGION_SPECIAL guarantees a suitable placement.
        region->p.size = PAGE_SIZE * 16;
        region->p.alignment = 0x1000;
        region->flags = DMA_REGION_SPECIAL;

        if (DmaAllocBuffer(region) < 0)
        {
            KPrint(PRINT_ERR "%s: alloc dma buffer faild!\n", __func__);
            IoDeleteDevice(device);
            return IO_FAILED;
        }
    }
    WaitQueueInit(&extension->waiter);

    if (Sb16Init(extension) < 0)
    {
        KPrint("sb16: init sb16 sound device failed!\n");
        IoDeleteDevice(device);
        return IO_FAILED;
    }

    IrqRegister(IRQ5_PARALLEL2, Sb16Handler, IRQ_DISABLE, "IRQ5", DRIVER_NAME, extension);
    return status;
}

// Sb16Exit - driver exit: unlink every device from the driver's device list
// and free the driver name. Always returns IO_SUCCESS.
//
// NOTE(review): devices are only unlinked (not passed to IoDeleteDevice), and
// neither the DMA buffers nor the IRQ5 handler registered in Sb16Enter are
// released here — confirm whether that is intentional.
static iostatus_t Sb16Exit(driver_object_t *driver)
{
    device_object_t *dev;
    device_object_t *tmp;

    // the _safe traversal is needed because each node is unlinked in the loop
    list_traversal_all_owner_to_next_safe(dev, tmp, &driver->device_list, list)
    {
        list_del_init(&dev->list);
    }

    string_del(&driver->name);
    return IO_SUCCESS;
}

// Sb16DevCtl - IOREQ_DEVCTL dispatch for volume and playback control.
//
// Volume is applied through Sb16SetVolume, which keeps only the low 4 bits of
// each channel, so the effective range is 0..0x0F; VOLADD/VOLDEC step the
// stored volume by one and saturate at the ends of that range so it cannot
// wrap (e.g. from silence to full volume).
// SOUNDIO_GETVOL / SOUNDIO_SETVOL treat `arg` as the address of a uint8_t.
// NOTE(review): a NULL arg is rejected, but the address is otherwise
// dereferenced unchecked — confirm callers pass a validated pointer.
//
// Returns IO_SUCCESS, or IO_FAILED for an unknown command or a NULL arg.
static iostatus_t Sb16DevCtl(device_object_t *device, io_request_t *ioreq)
{
    // largest value of the 4-bit mixer volume field
    const uint8_t vol_max = 0x0F;
    // DSP command "continue 16-bit DMA": resumes output paused by DSP_PAUSE_16BIT
    const uint8_t dsp_resume_16bit = 0xD6;

    device_extension_t *extension = device->device_extension;
    uint32_t cmd = ioreq->parame.devctl.code;
    uint32_t arg = ioreq->parame.devctl.arg;
    iostatus_t status = IO_SUCCESS;

    switch (cmd)
    {
    case SOUNDIO_VOLADD:
        // modify volume once per call; both channels get the same level
        if (extension->volume < vol_max)
            extension->volume++;
        Sb16SetVolume(extension, extension->volume, extension->volume);
        break;
    case SOUNDIO_VOLDEC:
        if (extension->volume > 0)
            extension->volume--;
        Sb16SetVolume(extension, extension->volume, extension->volume);
        break;
    case SOUNDIO_STOP:
        Sb16DspWrite(DSP_STOP_16BIT);
        break;
    case SOUNDIO_PAUSE:
        Sb16DspWrite(DSP_PAUSE_16BIT);
        break;
    case SOUNDIO_PLAY:
        // DSP_PLAY_16BIT starts a new transfer and must be followed by mode
        // and length bytes (see Sb16Request); sent alone it would leave the
        // DSP consuming the next commands as its operands. Resume instead.
        Sb16DspWrite(dsp_resume_16bit);
        break;
    case SOUNDIO_GETVOL:
        if (arg == 0)
        {
            status = IO_FAILED;
            break;
        }
        *(uint8_t *)arg = extension->volume;
        break;
    case SOUNDIO_SETVOL:
        if (arg == 0)
        {
            status = IO_FAILED;
            break;
        }
        extension->volume = *(uint8_t *)arg;
        Sb16SetVolume(extension, *(uint8_t *)arg, *(uint8_t *)arg);
        break;
    default:
        status = IO_FAILED;
        break;
    }
    ioreq->io_status.status = status;
    ioreq->io_status.info = 0;
    IoCompleteRequest(ioreq);
    return status;
}

// Sb16DriverFunc - fill in the driver object: lifecycle hooks, the I/O
// dispatch table (open, close, write, devctl) and the driver name.
// Always returns IO_SUCCESS.
static iostatus_t Sb16DriverFunc(driver_object_t *driver)
{
    // lifecycle hooks
    driver->driver_enter = Sb16Enter;
    driver->driver_exit = Sb16Exit;

    // request dispatch table
    driver->dispatch_fun[IOREQ_OPEN] = Sb16Open;
    driver->dispatch_fun[IOREQ_CLOSE] = Sb16Close;
    driver->dispatch_fun[IOREQ_WRITE] = Sb16Write;
    driver->dispatch_fun[IOREQ_DEVCTL] = Sb16DevCtl;

    string_new(&driver->name, DRIVER_NAME, DRIVER_NAME_LEN);
    return IO_SUCCESS;
}

// Sb16DriverEntry - init-time hook that registers the SB16 driver object;
// a failure is only logged.
static __init void Sb16DriverEntry(void)
{
    KPrint("[driver] create sb16 driver\n");
    if (DriverObjectCreate(Sb16DriverFunc) >= 0)
        return;

    KPrint(PRINT_ERR "[driver]:%s create driver %s failed\n", __func__, DRIVER_NAME);
}
driver_initcall(Sb16DriverEntry);
